{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g_nWetWWd_ns"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "cellView": "form",
        "id": "2pHVBk_seED1"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "cellView": "form",
        "id": "N_fMsQ-N8I7j"
      },
      "outputs": [],
      "source": [
        "#@title MIT License\n",
        "#\n",
        "# Copyright (c) 2017 François Chollet\n",
        "#\n",
        "# Permission is hereby granted, free of charge, to any person obtaining a\n",
        "# copy of this software and associated documentation files (the \"Software\"),\n",
        "# to deal in the Software without restriction, including without limitation\n",
        "# the rights to use, copy, modify, merge, publish, distribute, sublicense,\n",
        "# and/or sell copies of the Software, and to permit persons to whom the\n",
        "# Software is furnished to do so, subject to the following conditions:\n",
        "#\n",
        "# The above copyright notice and this permission notice shall be included in\n",
        "# all copies or substantial portions of the Software.\n",
        "#\n",
        "# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
        "# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
        "# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL\n",
        "# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
        "# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n",
        "# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER\n",
        "# DEALINGS IN THE SOFTWARE."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pZJ3uY9O17VN"
      },
      "source": [
        "# Save and load models"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "M4Ata7_wMul1"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://tensorflow.google.cn/tutorials/keras/save_and_load\">     <img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\">     View on TensorFlow.org</a></td>\n",
        "  <td>     <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/tutorials/keras/save_and_load.ipynb\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\">Run in Google Colab</a>   </td>\n",
        "  <td>     <a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/tutorials/keras/save_and_load.ipynb\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\">View source on GitHub</a>   </td>\n",
        "  <td>     <a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/zh-cn/tutorials/keras/save_and_load.ipynb\" class=\"_active_edit_href\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\">Download notebook</a> </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mBdde4YJeJKF"
      },
      "source": [
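        "You can save model progress during and after training. This means a model can resume where it left off and avoid long training times. Saving also means you can share your model and others can recreate your work. When publishing research models and techniques, most machine learning practitioners share:\n",
        "\n",
        "- code to create the model, and\n",
        "- the trained weights, or parameters, for the model\n",
        "\n",
        "Sharing this data helps others understand how the model works and try it themselves with new data.\n",
        "\n",
        "Caution: TensorFlow models are code, so be careful with untrusted code. See [Using TensorFlow Securely](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md) for details.\n",
        "\n",
        "### Options\n",
        "\n",
        "There are different ways to save TensorFlow models depending on the API you're using. This guide uses [tf.keras](https://tensorflow.google.cn/guide/keras), a high-level API to build and train models in TensorFlow. The new, high-level `.keras` format used in this tutorial is recommended for saving Keras objects, as it provides robust, efficient name-based saving that is often easier to debug than low-level or legacy formats. For more advanced saving or serialization workflows, especially those involving custom objects, refer to the [Save and load Keras models guide](https://tensorflow.google.cn/guide/keras/save_and_serialize). For other approaches, see the [Using the SavedModel format guide](../../guide/saved_model.ipynb)."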
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xCUREq7WXgvg"
      },
      "source": [
        "## Setup\n",
        "\n",
        "### Installs and imports"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7l0MiTOrXtNv"
      },
      "source": [
        "Install and import TensorFlow and dependencies:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "RzIOVSdnMYyO",
        "outputId": "f06dd22b-e147-4c90-c046-7a9bc78b1890",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (6.0.1)\n",
            "Requirement already satisfied: h5py in /usr/local/lib/python3.10/dist-packages (3.9.0)\n",
            "Requirement already satisfied: numpy>=1.17.3 in /usr/local/lib/python3.10/dist-packages (from h5py) (1.25.2)\n"
          ]
        }
      ],
      "source": [
        "!pip install pyyaml h5py  # Required to save models in HDF5 format"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "7Nm7Tyb-gRt-",
        "outputId": "d1c3522f-7fbb-4a35-9835-20589fa29f64",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "2.15.0\n"
          ]
        }
      ],
      "source": [
        "import os\n",
        "\n",
        "import tensorflow as tf\n",
        "from tensorflow import keras\n",
        "\n",
        "print(tf.version.VERSION)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SbGsznErXWt6"
      },
      "source": [
        "### Get an example dataset\n",
        "\n",
        "To demonstrate how to save and load weights, you'll use the [MNIST dataset](http://yann.lecun.com/exdb/mnist/). To speed up these runs, use the first 1000 examples:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "9rGfFwE9XVwz",
        "outputId": "63d34c2a-00f4-4624-92d9-ba68e5a0d079",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n",
            "11490434/11490434 [==============================] - 0s 0us/step\n"
          ]
        }
      ],
      "source": [
        "(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()\n",
        "\n",
        "train_labels = train_labels[:1000]\n",
        "test_labels = test_labels[:1000]\n",
        "\n",
        "train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0\n",
        "test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "anG3iVoXyZGI"
      },
      "source": [
        "### Define a model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wynsOBfby0Pa"
      },
      "source": [
        "Start by building a simple sequential model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "0HZbJIjxyX1S",
        "outputId": "b8263aec-0681-4138-8410-d699b3c163c5",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Model: \"sequential\"\n",
            "_________________________________________________________________\n",
            " Layer (type)                Output Shape              Param #   \n",
            "=================================================================\n",
            " dense (Dense)               (None, 512)               401920    \n",
            "                                                                 \n",
            " dropout (Dropout)           (None, 512)               0         \n",
            "                                                                 \n",
            " dense_1 (Dense)             (None, 10)                5130      \n",
            "                                                                 \n",
            "=================================================================\n",
            "Total params: 407050 (1.55 MB)\n",
            "Trainable params: 407050 (1.55 MB)\n",
            "Non-trainable params: 0 (0.00 Byte)\n",
            "_________________________________________________________________\n"
          ]
        }
      ],
      "source": [
        "# Define a simple sequential model\n",
        "def create_model():\n",
        "  model = tf.keras.Sequential([\n",
        "    keras.layers.Dense(512, activation='relu', input_shape=(784,)),\n",
        "    keras.layers.Dropout(0.2),\n",
        "    keras.layers.Dense(10)\n",
        "  ])\n",
        "\n",
        "  model.compile(optimizer='adam',\n",
        "                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "                metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])\n",
        "\n",
        "  return model\n",
        "\n",
        "# Create a basic model instance\n",
        "model = create_model()\n",
        "\n",
        "# Display the model's architecture\n",
        "model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "soDE0W_KH8rG"
      },
      "source": [
        "## Save checkpoints during training"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mRyd5qQQIXZm"
      },
      "source": [
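        "You can use a trained model without having to retrain it, or pick up training where you left off in case the training process was interrupted. The `tf.keras.callbacks.ModelCheckpoint` callback allows you to continually save the model both *during* and at *the end* of training.\n",
        "\n",
        "### Checkpoint callback usage\n",
        "\n",
        "Create a `tf.keras.callbacks.ModelCheckpoint` callback that saves only the weights during training:"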
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "IFPuhwntH8VH",
        "outputId": "89f4e365-224e-45ab-988d-3912668d5a34",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Epoch 1/10\n",
            "28/32 [=========================>....] - ETA: 0s - loss: 1.1993 - sparse_categorical_accuracy: 0.6685\n",
            "Epoch 1: saving model to training_1/cp.ckpt\n",
            "32/32 [==============================] - 1s 20ms/step - loss: 1.1330 - sparse_categorical_accuracy: 0.6850 - val_loss: 0.6854 - val_sparse_categorical_accuracy: 0.8000\n",
            "Epoch 2/10\n",
            "29/32 [==========================>...] - ETA: 0s - loss: 0.4397 - sparse_categorical_accuracy: 0.8631\n",
            "Epoch 2: saving model to training_1/cp.ckpt\n",
            "32/32 [==============================] - 0s 12ms/step - loss: 0.4291 - sparse_categorical_accuracy: 0.8670 - val_loss: 0.5252 - val_sparse_categorical_accuracy: 0.8390\n",
            "Epoch 3/10\n",
            "30/32 [===========================>..] - ETA: 0s - loss: 0.2779 - sparse_categorical_accuracy: 0.9240\n",
            "Epoch 3: saving model to training_1/cp.ckpt\n",
            "32/32 [==============================] - 0s 14ms/step - loss: 0.2781 - sparse_categorical_accuracy: 0.9250 - val_loss: 0.5102 - val_sparse_categorical_accuracy: 0.8370\n",
            "Epoch 4/10\n",
            "28/32 [=========================>....] - ETA: 0s - loss: 0.2153 - sparse_categorical_accuracy: 0.9531\n",
            "Epoch 4: saving model to training_1/cp.ckpt\n",
            "32/32 [==============================] - 0s 15ms/step - loss: 0.2135 - sparse_categorical_accuracy: 0.9530 - val_loss: 0.4825 - val_sparse_categorical_accuracy: 0.8430\n",
            "Epoch 5/10\n",
            "28/32 [=========================>....] - ETA: 0s - loss: 0.1604 - sparse_categorical_accuracy: 0.9654\n",
            "Epoch 5: saving model to training_1/cp.ckpt\n",
            "32/32 [==============================] - 0s 14ms/step - loss: 0.1643 - sparse_categorical_accuracy: 0.9640 - val_loss: 0.4445 - val_sparse_categorical_accuracy: 0.8540\n",
            "Epoch 6/10\n",
            "30/32 [===========================>..] - ETA: 0s - loss: 0.1096 - sparse_categorical_accuracy: 0.9823\n",
            "Epoch 6: saving model to training_1/cp.ckpt\n",
            "32/32 [==============================] - 0s 14ms/step - loss: 0.1104 - sparse_categorical_accuracy: 0.9820 - val_loss: 0.4150 - val_sparse_categorical_accuracy: 0.8710\n",
            "Epoch 7/10\n",
            "32/32 [==============================] - ETA: 0s - loss: 0.0809 - sparse_categorical_accuracy: 0.9890\n",
            "Epoch 7: saving model to training_1/cp.ckpt\n",
            "32/32 [==============================] - 0s 15ms/step - loss: 0.0809 - sparse_categorical_accuracy: 0.9890 - val_loss: 0.4223 - val_sparse_categorical_accuracy: 0.8690\n",
            "Epoch 8/10\n",
            "28/32 [=========================>....] - ETA: 0s - loss: 0.0641 - sparse_categorical_accuracy: 0.9922\n",
            "Epoch 8: saving model to training_1/cp.ckpt\n",
            "32/32 [==============================] - 0s 12ms/step - loss: 0.0629 - sparse_categorical_accuracy: 0.9930 - val_loss: 0.4281 - val_sparse_categorical_accuracy: 0.8550\n",
            "Epoch 9/10\n",
            "27/32 [========================>.....] - ETA: 0s - loss: 0.0458 - sparse_categorical_accuracy: 0.9977\n",
            "Epoch 9: saving model to training_1/cp.ckpt\n",
            "32/32 [==============================] - 0s 12ms/step - loss: 0.0493 - sparse_categorical_accuracy: 0.9960 - val_loss: 0.4227 - val_sparse_categorical_accuracy: 0.8640\n",
            "Epoch 10/10\n",
            "29/32 [==========================>...] - ETA: 0s - loss: 0.0440 - sparse_categorical_accuracy: 0.9978\n",
            "Epoch 10: saving model to training_1/cp.ckpt\n",
            "32/32 [==============================] - 0s 14ms/step - loss: 0.0440 - sparse_categorical_accuracy: 0.9980 - val_loss: 0.4174 - val_sparse_categorical_accuracy: 0.8690\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "<keras.src.callbacks.History at 0x7bccb44755a0>"
            ]
          },
          "execution_count": 7,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "checkpoint_path = \"training_1/cp.ckpt\"\n",
        "checkpoint_dir = os.path.dirname(checkpoint_path)\n",
        "\n",
        "# Create a callback that saves the model's weights\n",
        "cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,\n",
        "                                                 save_weights_only=True,\n",
        "                                                 verbose=1)\n",
        "\n",
        "# Train the model with the new callback\n",
        "model.fit(train_images,\n",
        "          train_labels,\n",
        "          epochs=10,\n",
        "          validation_data=(test_images, test_labels),\n",
        "          callbacks=[cp_callback])  # Pass callback to training\n",
        "\n",
        "# This may generate warnings related to saving the state of the optimizer.\n",
        "# These warnings (and similar warnings throughout this notebook)\n",
        "# are in place to discourage outdated usage, and can be ignored."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rlM-sgyJO084"
      },
      "source": [
        "This creates a single collection of TensorFlow checkpoint files that are updated at the end of each epoch:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "gXG5FVKFOVQ3",
        "outputId": "302054b9-3d8f-49ed-9257-9443112078a1",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "['cp.ckpt.index', 'cp.ckpt.data-00000-of-00001', 'checkpoint']"
            ]
          },
          "execution_count": 8,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "os.listdir(checkpoint_dir)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wlRN_f56Pqa9"
      },
      "source": [
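        "As long as two models share the same architecture, you can share weights between them. So, when restoring a model from weights only, create a model with the same architecture as the original model and then set its weights.\n",
        "\n",
        "Now rebuild a fresh, untrained model and evaluate it on the test set. An untrained model will perform at chance levels (about 10% accuracy):"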
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "Fp5gbuiaPqCT",
        "outputId": "957518bc-19b4-4f26-8db1-906462c0e012",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "32/32 - 0s - loss: 2.3974 - sparse_categorical_accuracy: 0.0940 - 218ms/epoch - 7ms/step\n",
            "Untrained model, accuracy:  9.40%\n"
          ]
        }
      ],
      "source": [
        "# Create a basic model instance\n",
        "model = create_model()\n",
        "\n",
        "# Evaluate the model\n",
        "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n",
        "print(\"Untrained model, accuracy: {:5.2f}%\".format(100 * acc))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1DTKpZssRSo3"
      },
      "source": [
        "Then load the weights from the checkpoint and re-evaluate:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "2IZxbwiRRSD2",
        "outputId": "70b45cdd-4477-4ca6-8c77-6e919102b745",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "32/32 - 0s - loss: 0.4174 - sparse_categorical_accuracy: 0.8690 - 84ms/epoch - 3ms/step\n",
            "Restored model, accuracy: 86.90%\n"
          ]
        }
      ],
      "source": [
        "# Loads the weights\n",
        "model.load_weights(checkpoint_path)\n",
        "\n",
        "# Re-evaluate the model\n",
        "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n",
        "print(\"Restored model, accuracy: {:5.2f}%\".format(100 * acc))"
      ]
    },
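    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketch-set-weights"
      },
      "source": [
        "Loading a checkpoint is one way to move weights between two models with identical architectures; `get_weights()` and `set_weights()` do the same thing in memory. The sketch below is illustrative only: `build_model`, `source_model`, and `target_model` are hypothetical names, the layer sizes are arbitrary, and it assumes the TensorFlow 2.x Keras API used in this notebook.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "\n",
        "def build_model():\n",
        "    # Same architecture on every call, so the weight lists line up one-to-one.\n",
        "    return tf.keras.Sequential([\n",
        "        tf.keras.Input(shape=(784,)),\n",
        "        tf.keras.layers.Dense(16, activation='relu'),\n",
        "        tf.keras.layers.Dense(10),\n",
        "    ])\n",
        "\n",
        "source_model = build_model()\n",
        "target_model = build_model()  # independently initialized, so outputs differ at first\n",
        "\n",
        "# Copy every weight array (two kernels and two biases) into the second model.\n",
        "target_model.set_weights(source_model.get_weights())\n",
        "\n",
        "x = np.random.rand(4, 784).astype('float32')\n",
        "same_output = np.allclose(source_model(x).numpy(), target_model(x).numpy())\n",
        "print(same_output)  # True\n",
        "```\n",
        "\n",
        "If the architectures differ (for example, a different number of units), `set_weights` raises an error because the array shapes no longer match, which is why a weights-only restore needs the original model-building code."
      ]
    },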
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bpAbKkAyVPV8"
      },
      "source": [
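        "### Checkpoint callback options\n",
        "\n",
        "The callback provides several options to give checkpoints unique names and adjust the checkpointing frequency.\n",
        "\n",
        "Train a new model, and save uniquely named checkpoints once every five epochs:"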
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "mQF_dlgIVOvq",
        "outputId": "ed8203c5-1dd7-4996-c2c5-b7ad029356b3",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "Epoch 5: saving model to training_2/cp-0005.ckpt\n",
            "\n",
            "Epoch 10: saving model to training_2/cp-0010.ckpt\n",
            "\n",
            "Epoch 15: saving model to training_2/cp-0015.ckpt\n",
            "\n",
            "Epoch 20: saving model to training_2/cp-0020.ckpt\n",
            "\n",
            "Epoch 25: saving model to training_2/cp-0025.ckpt\n",
            "\n",
            "Epoch 30: saving model to training_2/cp-0030.ckpt\n",
            "\n",
            "Epoch 35: saving model to training_2/cp-0035.ckpt\n",
            "\n",
            "Epoch 40: saving model to training_2/cp-0040.ckpt\n",
            "\n",
            "Epoch 45: saving model to training_2/cp-0045.ckpt\n",
            "\n",
            "Epoch 50: saving model to training_2/cp-0050.ckpt\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "<keras.src.callbacks.History at 0x7bccb3933640>"
            ]
          },
          "execution_count": 11,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Include the epoch in the file name (uses `str.format`)\n",
        "checkpoint_path = \"training_2/cp-{epoch:04d}.ckpt\"\n",
        "checkpoint_dir = os.path.dirname(checkpoint_path)\n",
        "\n",
        "batch_size = 32\n",
        "\n",
        "# Calculate the number of batches per epoch\n",
        "import math\n",
        "n_batches = len(train_images) / batch_size\n",
        "n_batches = math.ceil(n_batches)    # round up the number of batches to the nearest whole integer\n",
        "\n",
        "# Create a callback that saves the model's weights every 5 epochs\n",
        "cp_callback = tf.keras.callbacks.ModelCheckpoint(\n",
        "    filepath=checkpoint_path,\n",
        "    verbose=1,\n",
        "    save_weights_only=True,\n",
        "    save_freq=5*n_batches)\n",
        "\n",
        "# Create a new model instance\n",
        "model = create_model()\n",
        "\n",
        "# Save the weights using the `checkpoint_path` format\n",
        "model.save_weights(checkpoint_path.format(epoch=0))\n",
        "\n",
        "# Train the model with the new callback\n",
        "model.fit(train_images,\n",
        "          train_labels,\n",
        "          epochs=50,\n",
        "          batch_size=batch_size,\n",
        "          callbacks=[cp_callback],\n",
        "          validation_data=(test_images, test_labels),\n",
        "          verbose=0)"
      ]
    },
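    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketch-save-freq"
      },
      "source": [
        "`save_freq` is counted in batches (training steps), not epochs, which is why the cell above multiplies the number of batches per epoch by 5. The same arithmetic in plain Python, using the values from this notebook (1000 training examples, `batch_size=32`, 50 epochs); `total_steps` and `save_epochs` are illustrative names:\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "num_examples = 1000  # train_images has 1000 rows in this notebook\n",
        "batch_size = 32\n",
        "epochs = 50\n",
        "\n",
        "# Batches per epoch; the last batch is partial (1000 = 31 * 32 + 8).\n",
        "n_batches = math.ceil(num_examples / batch_size)\n",
        "\n",
        "# ModelCheckpoint saves after every `save_freq` batches.\n",
        "save_freq = 5 * n_batches\n",
        "\n",
        "# Convert each save step back into the epoch it falls on.\n",
        "total_steps = epochs * n_batches\n",
        "save_epochs = [step // n_batches for step in range(save_freq, total_steps + 1, save_freq)]\n",
        "\n",
        "print(n_batches, save_freq)  # 32 160\n",
        "print(save_epochs)           # [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]\n",
        "```\n",
        "\n",
        "The partial last batch is why `math.ceil` is needed: 1000 / 32 = 31.25, and Keras still runs a final batch of 8 examples. To save once per epoch instead, keep the default `save_freq='epoch'`."
      ]
    },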
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1zFrKTjjavWI"
      },
      "source": [
        "Now, review the resulting checkpoints and choose the latest one:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "p64q3-V4sXt0",
        "outputId": "ad6c358a-a20f-41c3-e4ad-f2be7404bca5",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "['cp-0010.ckpt.index',\n",
              " 'cp-0010.ckpt.data-00000-of-00001',\n",
              " 'cp-0005.ckpt.data-00000-of-00001',\n",
              " 'cp-0030.ckpt.index',\n",
              " 'cp-0045.ckpt.data-00000-of-00001',\n",
              " 'cp-0025.ckpt.data-00000-of-00001',\n",
              " 'cp-0035.ckpt.index',\n",
              " 'cp-0050.ckpt.data-00000-of-00001',\n",
              " 'cp-0000.ckpt.data-00000-of-00001',\n",
              " 'cp-0045.ckpt.index',\n",
              " 'cp-0000.ckpt.index',\n",
              " 'cp-0040.ckpt.data-00000-of-00001',\n",
              " 'checkpoint',\n",
              " 'cp-0005.ckpt.index',\n",
              " 'cp-0040.ckpt.index',\n",
              " 'cp-0015.ckpt.index',\n",
              " 'cp-0050.ckpt.index',\n",
              " 'cp-0015.ckpt.data-00000-of-00001',\n",
              " 'cp-0020.ckpt.data-00000-of-00001',\n",
              " 'cp-0035.ckpt.data-00000-of-00001',\n",
              " 'cp-0030.ckpt.data-00000-of-00001',\n",
              " 'cp-0020.ckpt.index',\n",
              " 'cp-0025.ckpt.index']"
            ]
          },
          "execution_count": 12,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "os.listdir(checkpoint_dir)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "1AN_fnuyR41H",
        "outputId": "2724fa2a-019a-477d-c921-a81e029f460c",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'training_2/cp-0050.ckpt'"
            ]
          },
          "execution_count": 13,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "latest = tf.train.latest_checkpoint(checkpoint_dir)\n",
        "latest"
      ]
    },
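    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketch-latest-ckpt"
      },
      "source": [
        "`tf.train.latest_checkpoint` works by reading the small `checkpoint` text file that sits next to the shards, not by sorting file names. A self-contained sketch with a plain `tf.train.Checkpoint` in a temporary directory (`ckpt_dir`, `state`, and `demo_latest` are illustrative names) shows the relationship:\n",
        "\n",
        "```python\n",
        "import os\n",
        "import tempfile\n",
        "\n",
        "import tensorflow as tf\n",
        "\n",
        "ckpt_dir = tempfile.mkdtemp()\n",
        "ckpt = tf.train.Checkpoint(step=tf.Variable(0))\n",
        "\n",
        "# Checkpoint.save() writes the shards and updates the `checkpoint` state file.\n",
        "first_path = ckpt.save(os.path.join(ckpt_dir, 'ckpt'))\n",
        "ckpt.step.assign_add(1)\n",
        "second_path = ckpt.save(os.path.join(ckpt_dir, 'ckpt'))\n",
        "\n",
        "# The state file records the most recent checkpoint prefix.\n",
        "state = tf.train.get_checkpoint_state(ckpt_dir)\n",
        "demo_latest = tf.train.latest_checkpoint(ckpt_dir)\n",
        "\n",
        "print(os.path.basename(first_path), os.path.basename(second_path))  # ckpt-1 ckpt-2\n",
        "print(os.path.basename(state.model_checkpoint_path))               # ckpt-2\n",
        "print(os.path.basename(demo_latest))                               # ckpt-2\n",
        "print('checkpoint' in os.listdir(ckpt_dir))                         # True\n",
        "```"
      ]
    },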
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Zk2ciGbKg561"
      },
      "source": [
        "Note: The default TensorFlow format only saves the 5 most recent checkpoints.\n",
        "\n",
        "To test, reset the model and load the latest checkpoint:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "3M04jyK-H3QK",
        "outputId": "6eabce1d-da9c-4537-d388-1db19f7009b6",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "32/32 - 0s - loss: 0.5066 - sparse_categorical_accuracy: 0.8710 - 282ms/epoch - 9ms/step\n",
            "Restored model, accuracy: 87.10%\n"
          ]
        }
      ],
      "source": [
        "# Create a new model instance\n",
        "model = create_model()\n",
        "\n",
        "# Load the previously saved weights\n",
        "model.load_weights(latest)\n",
        "\n",
        "# Re-evaluate the model\n",
        "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n",
        "print(\"Restored model, accuracy: {:5.2f}%\".format(100 * acc))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c2OxsJOTHxia"
      },
      "source": [
        "## What are these files?"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JtdYhvWnH2ib"
      },
      "source": [
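        "The above code stores the weights to a collection of [checkpoint](../../guide/checkpoint.ipynb)-formatted files that contain only the trained weights in a binary format. Checkpoints contain:\n",
        "\n",
        "- One or more shards that contain your model's weights.\n",
        "- An index file that indicates which weights are stored in which shard.\n",
        "\n",
        "If you are training a model on a single machine, you'll have one shard with the suffix `.data-00000-of-00001`."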
      ]
    },
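    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketch-list-variables"
      },
      "source": [
        "To see what the index file records, you can write a tiny checkpoint and inspect it with `tf.train.list_variables`, which returns `(name, shape)` pairs. This sketch uses a plain `tf.train.Checkpoint` rather than the Keras model above; `demo_dir`, `kernel`, and `bias` are illustrative names:\n",
        "\n",
        "```python\n",
        "import os\n",
        "import tempfile\n",
        "\n",
        "import tensorflow as tf\n",
        "\n",
        "demo_dir = tempfile.mkdtemp()\n",
        "kernel = tf.Variable(tf.ones([3, 2]))\n",
        "bias = tf.Variable(tf.zeros([2]))\n",
        "\n",
        "# Checkpoint.write() saves one checkpoint at exactly this prefix\n",
        "# (no counter suffix, no `checkpoint` state file).\n",
        "prefix = tf.train.Checkpoint(kernel=kernel, bias=bias).write(os.path.join(demo_dir, 'weights'))\n",
        "\n",
        "# Expect an index file plus a single data shard.\n",
        "files = sorted(os.listdir(demo_dir))\n",
        "print(files)\n",
        "\n",
        "# The index maps each saved variable name to its shape.\n",
        "shapes = dict(tf.train.list_variables(prefix))\n",
        "print(shapes['kernel/.ATTRIBUTES/VARIABLE_VALUE'])  # [3, 2]\n",
        "```\n",
        "\n",
        "The same call works on the Keras checkpoints written earlier, for example `tf.train.list_variables(latest)`, where the names reflect the model's layer structure."
      ]
    },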
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S_FA-ZvxuXQV"
      },
      "source": [
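        "## Manually save weights\n",
        "\n",
        "To save weights manually, use `tf.keras.Model.save_weights`. By default, `tf.keras` (and the `Model.save_weights` method in particular) uses the TensorFlow [Checkpoint](../../guide/checkpoint.ipynb) format with a `.ckpt` extension. To save in the HDF5 format with a `.h5` extension, refer to the [Save and load models](https://tensorflow.google.cn/guide/keras/save_and_serialize) guide."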
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "R7W5plyZ-u9X",
        "outputId": "1210336d-b130-47ba-a478-d0f77df30f03",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "32/32 - 0s - loss: 0.5066 - sparse_categorical_accuracy: 0.8710 - 273ms/epoch - 9ms/step\n",
            "Restored model, accuracy: 87.10%\n"
          ]
        }
      ],
      "source": [
        "# Save the weights\n",
        "model.save_weights('./checkpoints/my_checkpoint')\n",
        "\n",
        "# Create a new model instance\n",
        "model = create_model()\n",
        "\n",
        "# Restore the weights\n",
        "model.load_weights('./checkpoints/my_checkpoint')\n",
        "\n",
        "# Evaluate the model\n",
        "loss, acc = model.evaluate(test_images, test_labels, verbose=2)\n",
        "print(\"Restored model, accuracy: {:5.2f}%\".format(100 * acc))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kOGlxPRBEvV1"
      },
      "source": [
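        "## Save the entire model\n",
        "\n",
        "Call `tf.keras.Model.save` to save a model's architecture, weights, and training configuration in a single `model.keras` zip archive.\n",
        "\n",
        "An entire model can be saved in three different file formats (the new `.keras` format and two legacy formats: `SavedModel` and `HDF5`). Saving a model as `path/to/model.keras` automatically saves in the latest format.\n",
        "\n",
        "**Note:** For Keras objects, it's recommended to use the new high-level `.keras` format for richer, name-based saving and reloading, which is easier to debug. The low-level SavedModel format and legacy H5 format continue to be supported for existing code.\n",
        "\n",
        "You can switch to the SavedModel format by:\n",
        "\n",
        "- Passing `save_format='tf'` to `save()`\n",
        "- Passing a filename without an extension\n",
        "\n",
        "You can switch to the H5 format by:\n",
        "\n",
        "- Passing `save_format='h5'` to `save()`\n",
        "- Passing a filename that ends in `.h5`\n",
        "\n",
        "Saving a fully functional model is very useful: you can load it in TensorFlow.js ([Saved Model](https://tensorflow.google.cn/js/tutorials/conversion/import_saved_model), [HDF5](https://tensorflow.google.cn/js/tutorials/conversion/import_keras)) and then train and run it in web browsers, or convert it to run on mobile devices using TensorFlow Lite ([Saved Model](https://tensorflow.google.cn/lite/models/convert/#convert_a_savedmodel_recommended_), [HDF5](https://tensorflow.google.cn/lite/models/convert/#convert_a_keras_model_)).\n",
        "\n",
        "Custom objects (for example, subclassed models or layers) require special attention when saving and loading. See the **Saving custom objects** section below."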
      ]
    },
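    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sketch-save-format"
      },
      "source": [
        "The format-selection rules listed above can be summarized as a small lookup. `infer_save_format` below is a hypothetical helper written for illustration, not a Keras API, and it does not reproduce how Keras handles conflicting inputs (such as a `.keras` filename combined with `save_format='h5'`):\n",
        "\n",
        "```python\n",
        "import os\n",
        "\n",
        "def infer_save_format(path, save_format=None):\n",
        "    # Hypothetical helper: restates the rules above; it is not Keras's own logic.\n",
        "    if save_format is not None:\n",
        "        return save_format          # explicit 'tf' or 'h5'\n",
        "    extension = os.path.splitext(path)[1]\n",
        "    if extension == '.keras':\n",
        "        return 'keras'              # new Keras v3 zip archive\n",
        "    if extension == '.h5':\n",
        "        return 'h5'                 # legacy HDF5 file\n",
        "    if extension == '':\n",
        "        return 'tf'                 # legacy SavedModel directory\n",
        "    raise ValueError(f'unrecognized extension: {extension!r}')\n",
        "\n",
        "for p in ['path/to/model.keras', 'my_model.h5', 'saved_model/my_model']:\n",
        "    print(p, '->', infer_save_format(p))\n",
        "```"
      ]
    },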
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0fRGnlHMrkI7"
      },
      "source": [
        "### New high-level `.keras` format"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eqO8jj7GsCDn"
      },
      "source": [
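        "The new Keras v3 saving format, marked by the `.keras` extension, is a simpler, more efficient format that implements name-based saving, ensuring that what you load is exactly what you saved, from Python's perspective. This makes debugging much easier, and it is the recommended format for Keras.\n",
        "\n",
        "The section below illustrates how to save and restore a model in the `.keras` format."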
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "3f55mAXwukUX",
        "outputId": "4952b709-d3e8-4fa5-f9a6-301aea4cd1f8",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Epoch 1/5\n",
            "32/32 [==============================] - 1s 12ms/step - loss: 1.1660 - sparse_categorical_accuracy: 0.6640\n",
            "Epoch 2/5\n",
            "32/32 [==============================] - 0s 9ms/step - loss: 0.4405 - sparse_categorical_accuracy: 0.8720\n",
            "Epoch 3/5\n",
            "32/32 [==============================] - 0s 7ms/step - loss: 0.2807 - sparse_categorical_accuracy: 0.9290\n",
            "Epoch 4/5\n",
            "32/32 [==============================] - 0s 8ms/step - loss: 0.2046 - sparse_categorical_accuracy: 0.9520\n",
            "Epoch 5/5\n",
            "32/32 [==============================] - 0s 8ms/step - loss: 0.1511 - sparse_categorical_accuracy: 0.9700\n"
          ]
        }
      ],
      "source": [
        "# Create and train a new model instance.\n",
        "model = create_model()\n",
        "model.fit(train_images, train_labels, epochs=5)\n",
        "\n",
        "# Save the entire model as a `.keras` zip archive.\n",
        "model.save('my_model.keras')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iHqwaun5g8lD"
      },
      "source": [
        "Reload a fresh Keras model from the `.keras` zip archive:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "id": "HyfUMOZwux_-",
        "outputId": "1965e220-7395-46a9-b191-2108cbc90f20",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Model: \"sequential_5\"\n",
            "_________________________________________________________________\n",
            " Layer (type)                Output Shape              Param #   \n",
            "=================================================================\n",
            " dense_10 (Dense)            (None, 512)               401920    \n",
            "                                                                 \n",
            " dropout_5 (Dropout)         (None, 512)               0         \n",
            "                                                                 \n",
            " dense_11 (Dense)            (None, 10)                5130      \n",
            "                                                                 \n",
            "=================================================================\n",
            "Total params: 407050 (1.55 MB)\n",
            "Trainable params: 407050 (1.55 MB)\n",
            "Non-trainable params: 0 (0.00 Byte)\n",
            "_________________________________________________________________\n"
          ]
        }
      ],
      "source": [
        "new_model = tf.keras.models.load_model('my_model.keras')\n",
        "\n",
        "# Show the model architecture\n",
        "new_model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9Cn3pSBqvJ5f"
      },
      "source": [
        "Try running evaluate and predict with the loaded model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "id": "8BT4mHNIvMdW",
        "outputId": "cc6dd5f1-deeb-4cf0-c575-4684a4e600ab",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "32/32 - 0s - loss: 0.4373 - sparse_categorical_accuracy: 0.8620 - 198ms/epoch - 6ms/step\n",
            "Restored model, accuracy: 86.20%\n",
            "32/32 [==============================] - 0s 2ms/step\n",
            "(1000, 10)\n"
          ]
        }
      ],
      "source": [
        "# Evaluate the restored model\n",
        "loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)\n",
        "print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))\n",
        "\n",
        "print(new_model.predict(test_images).shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kPyhgcoVzqUB"
      },
      "source": [
        "### SavedModel format"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LtcN4VIb7JkK"
      },
      "source": [
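        "The SavedModel format is another way to serialize models. Models saved in this format can be restored using `tf.keras.models.load_model` and are compatible with TensorFlow Serving. The [SavedModel guide](../../guide/saved_model.ipynb) goes into detail about how to serve and inspect a SavedModel. The section below illustrates the steps to save and restore the model."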
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "id": "sI1YvCDFzpl3",
        "outputId": "17250d39-dbb3-4fc2-8e8e-92f493385f64",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Epoch 1/5\n",
            "32/32 [==============================] - 1s 6ms/step - loss: 1.1855 - sparse_categorical_accuracy: 0.6720\n",
            "Epoch 2/5\n",
            "32/32 [==============================] - 0s 6ms/step - loss: 0.4246 - sparse_categorical_accuracy: 0.8820\n",
            "Epoch 3/5\n",
            "32/32 [==============================] - 0s 7ms/step - loss: 0.2714 - sparse_categorical_accuracy: 0.9330\n",
            "Epoch 4/5\n",
            "32/32 [==============================] - 0s 8ms/step - loss: 0.2127 - sparse_categorical_accuracy: 0.9440\n",
            "Epoch 5/5\n",
            "32/32 [==============================] - 0s 8ms/step - loss: 0.1550 - sparse_categorical_accuracy: 0.9670\n"
          ]
        }
      ],
      "source": [
        "# Create and train a new model instance.\n",
        "model = create_model()\n",
        "model.fit(train_images, train_labels, epochs=5)\n",
        "\n",
        "# Save the entire model as a SavedModel.\n",
        "!mkdir -p saved_model\n",
        "model.save('saved_model/my_model')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iUvT_3qE8hV5"
      },
      "source": [
        "The SavedModel format is a directory containing a protobuf binary and a TensorFlow checkpoint. Inspect the saved model directory:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "metadata": {
        "id": "sq8fPglI1RWA",
        "outputId": "6400baeb-2b4b-40cb-cbc5-fd4b8c0016cd",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "my_model\n",
            "assets\tfingerprint.pb\tkeras_metadata.pb  saved_model.pb  variables\n"
          ]
        }
      ],
      "source": [
        "# my_model directory\n",
        "!ls saved_model\n",
        "\n",
        "# Contains an assets folder, saved_model.pb, a variables folder and metadata files.\n",
        "!ls saved_model/my_model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "B7qfpvpY9HCe"
      },
      "source": [
        "Reload a fresh Keras model from the saved model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 21,
      "metadata": {
        "id": "0YofwHdN0pxa",
        "outputId": "f59e8e5c-ef22-493d-b9e4-1f6e987dbdf3",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Model: \"sequential_6\"\n",
            "_________________________________________________________________\n",
            " Layer (type)                Output Shape              Param #   \n",
            "=================================================================\n",
            " dense_12 (Dense)            (None, 512)               401920    \n",
            "                                                                 \n",
            " dropout_6 (Dropout)         (None, 512)               0         \n",
            "                                                                 \n",
            " dense_13 (Dense)            (None, 10)                5130      \n",
            "                                                                 \n",
            "=================================================================\n",
            "Total params: 407050 (1.55 MB)\n",
            "Trainable params: 407050 (1.55 MB)\n",
            "Non-trainable params: 0 (0.00 Byte)\n",
            "_________________________________________________________________\n"
          ]
        }
      ],
      "source": [
        "new_model = tf.keras.models.load_model('saved_model/my_model')\n",
        "\n",
        "# Check its architecture\n",
        "new_model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uWwgNaz19TH2"
      },
      "source": [
        "The restored model is compiled with the same arguments as the original model. Try running evaluate and predict with the loaded model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "metadata": {
        "id": "Yh5Mu0yOgE5J",
        "outputId": "ddac27b7-10a7-4a9d-a9f1-8f0e01ce493d",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.\n",
            "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.1\n",
            "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.2\n",
            "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.3\n",
            "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.4\n",
            "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.5\n",
            "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.6\n",
            "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.7\n",
            "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.8\n",
            "WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.\n",
            "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.1\n",
            "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.2\n",
            "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.3\n",
            "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.4\n",
            "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.5\n",
            "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.6\n",
            "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.7\n",
            "WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer._variables.8\n"
          ]
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "32/32 - 0s - loss: 0.4158 - sparse_categorical_accuracy: 0.8640 - 233ms/epoch - 7ms/step\n",
            "Restored model, accuracy: 86.40%\n",
            "32/32 [==============================] - 0s 3ms/step\n",
            "(1000, 10)\n"
          ]
        }
      ],
      "source": [
        "# Evaluate the restored model\n",
        "loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)\n",
        "print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))\n",
        "\n",
        "print(new_model.predict(test_images).shape)"
      ]
    },
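    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To inspect what a SavedModel exposes for serving, you can also load it with the low-level `tf.saved_model.load` API instead of `tf.keras.models.load_model`. The following is a minimal sketch, assuming that `tf.keras` (Keras 2) writes a `serving_default` signature when it saves a model with known inputs to SavedModel; the toy model and the `inspect_demo` path are hypothetical stand-ins for `saved_model/my_model`.\n",
        "\n",
        "```python\n",
        "import tensorflow as tf\n",
        "\n",
        "# Hypothetical toy model and export path, standing in for saved_model/my_model.\n",
        "model = tf.keras.Sequential([\n",
        "    tf.keras.layers.Dense(4, input_shape=(8,)),\n",
        "    tf.keras.layers.Dense(2),\n",
        "])\n",
        "model.save('inspect_demo')  # no extension: SavedModel directory in tf.keras (Keras 2)\n",
        "\n",
        "# Load with the low-level TensorFlow API rather than tf.keras.models.load_model.\n",
        "loaded = tf.saved_model.load('inspect_demo')\n",
        "print(list(loaded.signatures.keys()))  # names of the exported signatures\n",
        "\n",
        "serving_fn = loaded.signatures['serving_default']\n",
        "print(serving_fn.structured_input_signature)  # (positional args, keyword TensorSpecs)\n",
        "print(serving_fn.structured_outputs)\n",
        "\n",
        "# Signatures take keyword arguments and return a dict of output tensors.\n",
        "input_name = list(serving_fn.structured_input_signature[1].keys())[0]\n",
        "outputs = serving_fn(**{input_name: tf.zeros([1, 8])})\n",
        "print({name: t.shape for name, t in outputs.items()})\n",
        "```\n",
        "\n",
        "The object returned by `tf.saved_model.load` is not a Keras model, so it has no `fit` or `summary` methods; it only exposes the saved functions and variables. From a shell, `saved_model_cli show --dir inspect_demo --all` prints a similar overview of the signatures."
      ]
    },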
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SkGwf-50zLNn"
      },
      "source": [
        "### HDF5 format\n",
        "\n",
        "Keras provides a basic legacy high-level save format using the [HDF5](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) standard."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "metadata": {
        "id": "m2dkmJVCGUia",
        "outputId": "38c761ce-dce2-445f-9300-2b97b4307bb8",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Epoch 1/5\n",
            "32/32 [==============================] - 1s 6ms/step - loss: 1.1928 - sparse_categorical_accuracy: 0.6460\n",
            "Epoch 2/5\n",
            "32/32 [==============================] - 0s 6ms/step - loss: 0.4275 - sparse_categorical_accuracy: 0.8760\n",
            "Epoch 3/5\n",
            "32/32 [==============================] - 0s 6ms/step - loss: 0.2840 - sparse_categorical_accuracy: 0.9300\n",
            "Epoch 4/5\n",
            "32/32 [==============================] - 0s 6ms/step - loss: 0.2118 - sparse_categorical_accuracy: 0.9520\n",
            "Epoch 5/5\n",
            "32/32 [==============================] - 0s 7ms/step - loss: 0.1546 - sparse_categorical_accuracy: 0.9660\n"
          ]
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/keras/src/engine/training.py:3103: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native Keras format, e.g. `model.save('my_model.keras')`.\n",
            "  saving_api.save_model(\n"
          ]
        }
      ],
      "source": [
        "# Create and train a new model instance.\n",
        "model = create_model()\n",
        "model.fit(train_images, train_labels, epochs=5)\n",
        "\n",
        "# Save the entire model to an HDF5 file.\n",
        "# The '.h5' extension indicates that the model should be saved to HDF5.\n",
        "model.save('my_model.h5')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GWmttMOqS68S"
      },
      "source": [
        "Now, recreate the model from that file:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 24,
      "metadata": {
        "id": "5NDMO_7kS6Do",
        "outputId": "3b05484d-e809-41de-8c03-54b4b7fd1fa4",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Model: \"sequential_7\"\n",
            "_________________________________________________________________\n",
            " Layer (type)                Output Shape              Param #   \n",
            "=================================================================\n",
            " dense_14 (Dense)            (None, 512)               401920    \n",
            "                                                                 \n",
            " dropout_7 (Dropout)         (None, 512)               0         \n",
            "                                                                 \n",
            " dense_15 (Dense)            (None, 10)                5130      \n",
            "                                                                 \n",
            "=================================================================\n",
            "Total params: 407050 (1.55 MB)\n",
            "Trainable params: 407050 (1.55 MB)\n",
            "Non-trainable params: 0 (0.00 Byte)\n",
            "_________________________________________________________________\n"
          ]
        }
      ],
      "source": [
        "# Recreate the exact same model, including its weights and the optimizer\n",
        "new_model = tf.keras.models.load_model('my_model.h5')\n",
        "\n",
        "# Show the model architecture\n",
        "new_model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JXQpbTicTBwt"
      },
      "source": [
        "Check its accuracy:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 25,
      "metadata": {
        "id": "jwEaj9DnTCVA",
        "outputId": "5852af17-4003-496e-a50c-752189bed2fe",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "32/32 - 0s - loss: 0.4356 - sparse_categorical_accuracy: 0.8600 - 179ms/epoch - 6ms/step\n",
            "Restored model, accuracy: 86.00%\n"
          ]
        }
      ],
      "source": [
        "loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)\n",
        "print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dGXqd4wWJl8O"
      },
      "source": [
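        "Keras saves models by inspecting their architectures. This technique saves everything:\n",
        "\n",
        "- The weight values\n",
        "- The model's architecture\n",
        "- The model's training configuration (what you pass to the `.compile()` method)\n",
        "- The optimizer and its state, if any (this enables you to restart training where you left off)\n",
        "\n",
        "Keras is not able to save the `v1.x` optimizers (from `tf.compat.v1.train`) since they aren't compatible with checkpoints. For v1.x optimizers, you need to re-compile the model after loading, and the state of the optimizer is lost.\n"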
      ]
    },
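    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because the optimizer state and the `compile()` configuration are stored with the model, a restored model can continue training without being recompiled. The following is a minimal sketch of this, assuming `tf.keras` (Keras 2) and the HDF5 format; the random data and the toy model are hypothetical stand-ins for the MNIST subset and `create_model()` used above.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "\n",
        "# Hypothetical random data standing in for the tutorial's MNIST subset.\n",
        "x = np.random.rand(64, 784).astype('float32')\n",
        "y = np.random.randint(0, 10, size=(64,))\n",
        "\n",
        "model = tf.keras.Sequential([\n",
        "    tf.keras.layers.Dense(16, activation='relu', input_shape=(784,)),\n",
        "    tf.keras.layers.Dense(10),\n",
        "])\n",
        "model.compile(optimizer='adam',\n",
        "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "              metrics=['sparse_categorical_accuracy'])\n",
        "model.fit(x, y, epochs=1, batch_size=32, verbose=0)\n",
        "model.save('resume_demo.h5')\n",
        "\n",
        "# load_model restores the weights, architecture, compile() settings and optimizer.\n",
        "restored = tf.keras.models.load_model('resume_demo.h5')\n",
        "print(int(restored.optimizer.iterations))  # the optimizer's step counter after loading\n",
        "\n",
        "# The model comes back compiled, so training continues without calling compile().\n",
        "restored.fit(x, y, epochs=1, batch_size=32, verbose=0)\n",
        "```"
      ]
    },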
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kAUKJQyGqTNH"
      },
      "source": [
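        "### Saving custom objects\n",
        "\n",
        "If you are using the SavedModel format, you can skip this section. The key difference between the high-level `.keras`/HDF5 formats and the low-level SavedModel format is that the `.keras`/HDF5 formats use object configs to save the model architecture, while SavedModel saves the execution graph. Thus, SavedModels are able to save custom objects like subclassed models and custom layers without requiring the original code. However, this also makes low-level SavedModels harder to debug, so we recommend using the high-level `.keras` format instead, since it is name-based and native to Keras.\n",
        "\n",
        "To save custom objects to `.keras` and HDF5, you must do the following:\n",
        "\n",
        "1. Define a `get_config` method in your object, and optionally a `from_config` classmethod.\n",
        "    - `get_config(self)` returns a JSON-serializable dictionary of parameters needed to recreate the object.\n",
        "    - `from_config(cls, config)` uses the config returned from `get_config` to create a new object. By default, this function uses the config as initialization kwargs (`return cls(**config)`).\n",
        "2. Pass the custom objects to the model in one of three ways:\n",
        "    - Register the custom object with the `@tf.keras.utils.register_keras_serializable` decorator. **(recommended)**\n",
        "    - Pass the object directly to the `custom_objects` argument when loading the model. The argument must be a dictionary mapping the string class name to the Python class, for example, `tf.keras.models.load_model(path, custom_objects={'CustomLayer': CustomLayer})`.\n",
        "    - Use `tf.keras.utils.custom_object_scope` with the object included in the `custom_objects` dictionary argument, and place a `tf.keras.models.load_model(path)` call within the scope.\n",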
        "\n",
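        "The steps above can be sketched with a small custom layer. This is a minimal example, not the tutorial's own code: the `ScaledDense` layer, the `my_package` registration package and the `custom_demo.keras` path are hypothetical, and it assumes a recent TF 2.x release with `tf.keras` and the `.keras` format. `get_config` extends the base layer config with the constructor arguments, and the inherited `from_config` rebuilds the layer with `cls(**config)`.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "\n",
        "\n",
        "# Hypothetical custom layer, registered under an illustrative package name.\n",
        "@tf.keras.utils.register_keras_serializable(package='my_package')\n",
        "class ScaledDense(tf.keras.layers.Layer):\n",
        "    def __init__(self, units, scale=1.0, **kwargs):\n",
        "        super().__init__(**kwargs)\n",
        "        self.units = units\n",
        "        self.scale = scale\n",
        "\n",
        "    def build(self, input_shape):\n",
        "        self.w = self.add_weight(name='w', shape=(input_shape[-1], self.units),\n",
        "                                 initializer='random_normal', trainable=True)\n",
        "\n",
        "    def call(self, inputs):\n",
        "        return self.scale * tf.matmul(inputs, self.w)\n",
        "\n",
        "    def get_config(self):\n",
        "        # Base config (name, trainable, dtype, ...) plus the constructor args.\n",
        "        config = super().get_config()\n",
        "        config.update({'units': self.units, 'scale': self.scale})\n",
        "        return config\n",
        "    # from_config is inherited and calls cls(**config).\n",
        "\n",
        "\n",
        "inputs = tf.keras.Input(shape=(8,))\n",
        "model = tf.keras.Model(inputs, ScaledDense(2, scale=0.5)(inputs))\n",
        "model.save('custom_demo.keras')\n",
        "\n",
        "# The class is registered, so no custom_objects argument is needed here.\n",
        "reloaded = tf.keras.models.load_model('custom_demo.keras')\n",
        "\n",
        "# Alternatives for an unregistered class:\n",
        "# tf.keras.models.load_model(path, custom_objects={'ScaledDense': ScaledDense})\n",
        "# with tf.keras.utils.custom_object_scope({'ScaledDense': ScaledDense}):\n",
        "#     tf.keras.models.load_model(path)\n",
        "\n",
        "# The reloaded model should reproduce the original predictions.\n",
        "x = np.random.rand(3, 8).astype('float32')\n",
        "np.testing.assert_allclose(model.predict(x, verbose=0), reloaded.predict(x, verbose=0), rtol=1e-6)\n",
        "```\n",
        "\n",
        "Registering the class is convenient, but re-running the decorated cell in the same session may raise an error because the name is already registered.\n",
        "\n",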
        "Refer to the [Writing layers and models from scratch](https://tensorflow.google.cn/guide/keras/custom_layers_and_models) tutorial for examples of custom objects and `get_config`.\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "save_and_load.ipynb",
      "toc_visible": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}